R"""
"""
#
import numpy as onp
import numpy.typing as onpt
import time
import torch
import torch.nn as nn
import more_itertools as xitertools
from typing import List, Tuple, Optional, Union, cast
from ..indexable import FrameworkIndexable
from ...meta.dyngraph.sparse.staedge import DynamicAdjacencyListStaticEdge
from ...meta.dyngraph.sparse.dynedge import DynamicAdjacencyListDynamicEdge
from ..types import TIMECOST
from ..transfer import transfer
from ...meta.batch import batchize, batchize2
from lib.model.activate import activatize
import torch.optim as optim
from dgl import DGLGraph
import dgl
from sklearn import linear_model
from sklearn.utils.extmath import randomized_svd
import scipy.sparse as sp
import lib.utils.ppr as ppr
import numpy as onp
import networkx as nx
import os
import csv

class Static(torch.nn.Module):
    R"""
    Treat a static feature as a dynamic one holding a single step.
    """
    def forward(
        self,
        tensor: torch.Tensor,
        /,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        R"""
        Wrap the input as a length-one sequence.

        Returns a pair: the input reshaped with one extra leading axis of
        size one, and the input tensor itself.
        """
        # Prepend the singleton axis; the second item is the untouched input.
        stepped = tensor.reshape((1, *tensor.shape))
        return (stepped, tensor)
    

class FrameworkSdgnn(
    FrameworkIndexable[
        Union[DynamicAdjacencyListStaticEdge, DynamicAdjacencyListDynamicEdge],
    ],
):
    R"""
    Framework with dynamic graph meta samples.

    The metaset is either a static-edge or a dynamic-edge dynamic adjacency
    list (see the type parameter of the base class).
    """
    # Not read anywhere in this class body; presumably consumed by the base
    # framework when it builds batches -- TODO confirm.
    BATCH_PAD = True
    # When true, `evaluate_sdgnn` appends one row of metrics to the CSV file
    # located at `self.filename`.
    save_measurement = True


    def freeze_target_model(self):
        for param in self.neuralnet.tgnn.snn_node.parameters():
            param.requires_grad = False
        for param in self.neuralnet.tgnn.gnnx2.parameters():
            param.requires_grad = False
        for param in self.neuralnet.mlp.parameters():
            param.requires_grad = False

    
    def prepare_sdgnn(self, lr, weight_decay):
        R"""
        Set up everything needed before SDGNN training.

        Moves the transformation model to `self.device`, creates Adam
        optimizers for the transformation model and for the MLP (both with
        the given `lr` and `weight_decay`), dumps the metaset graph to
        `graph.graphml` under `self.dir_this_time`, initializes
        `self.weights` from a dense top-k PPR matrix, installs an MSE loss
        as `self.loss_func`, and freezes the target model.
        """
        # Transformation model and its optimizer.
        self.neuralnet.transformation_model = (
            self.neuralnet.transformation_model.to(self.device)
        )
        self.optimizer_sdgnn = optim.Adam(
            self.neuralnet.transformation_model.parameters(),
            lr=lr, weight_decay=weight_decay,
        )

        # Adjacency matrix of the metaset: one unit entry per listed edge,
        # assembled in COO form and converted to CSR.
        node_ids = onp.arange(self.metaset.num_nodes)
        edge_ones = onp.ones(self.metaset.edge_srcs.shape[0])
        adjacency = sp.coo_matrix(
            (edge_ones, (self.metaset.edge_srcs, self.metaset.edge_dsts)),
            shape=(self.metaset.num_nodes, self.metaset.num_nodes),
        ).tocsr()

        # Save the graph as GraphML next to the other outputs of this run.
        multigraph = nx.from_scipy_sparse_array(
            adjacency, create_using=nx.MultiGraph,
        )
        graph_path = os.path.join(self.dir_this_time, "graph.graphml")
        nx.write_graphml(multigraph, graph_path)
        print("Saving graph to ", graph_path)

        # Initial combination weights: dense form of the top-k PPR matrix.
        self.weights = ppr.topk_ppr_matrix(
            adj_matrix=adjacency, alpha=0.5, eps=1e-4, idx=node_ids, topk=64,
        ).toarray()
        self.loss_func = nn.MSELoss()

        # The MLP optimizer is created after freezing; the MLP parameters
        # have gradients disabled at this point.
        self.freeze_target_model()
        self.optimizer_mlp = optim.Adam(
            self.neuralnet.mlp.parameters(),
            lr=lr, weight_decay=weight_decay,
        )
    
    def get_snn_embedding(self, node_feats):
        self.neuralnet.tgnn.snn_node.eval()
        with torch.no_grad():
            (snn_embeds, _) = (
                self.neuralnet.tgnn.snn_node.forward(torch.permute(node_feats, (2, 0, 1)))
            )
        # num of nodes x batch (snapshot) x embed_size
        return snn_embeds


    def get_node_embedding(self,
        edge_tuples: torch.Tensor, edge_feats: torch.Tensor,
        edge_labels: torch.Tensor, edge_ranges: torch.Tensor,
        edge_times: torch.Tensor, node_feats: torch.Tensor,
        node_labels: torch.Tensor, node_times: torch.Tensor,
        node_masks: torch.Tensor,):
        # In sequence-then-graph flow, dynamic edges are already aggregated
        # together and all steps has exactly the same aggregated edge data, and
        # we will only use data from last step.
        self.neuralnet.tgnn.snn_node.eval()
        self.neuralnet.tgnn.gnnx2.eval()
        total_edges = 0
        with torch.no_grad():
            (edge_embeds, _) = self.neuralnet.tgnn.snn_edge.forward(edge_feats)
            total_edges = total_edges + len(edge_embeds[-1])
            # Take only embedding from the last step.
            # The graph convolution will also use connectivies from last step.
            
            (node_embeds, _) = (
                self.neuralnet.tgnn.snn_node.forward(torch.permute(node_feats, (2, 0, 1)))
            )
                    #
            node_embeds = self.neuralnet.tgnn.gnnx2.forward(
                edge_tuples, edge_embeds[-1], self.neuralnet.tgnn.activate(node_embeds[-1]),
            )
        return node_embeds, node_feats, node_times, node_masks


    def nodesplit_masks(
        self,
        meta_indices: List[int], meta_batch_size: int,
        /,
    ) -> onpt.NDArray[onp.generic]:
        R"""
        Translate given metaset indices into a mask over node slots.

        The result is a flat int64 array of length
        `meta_batch_size * self.metaset.num_nodes`.
        For node spindle, only the nodes listed in `meta_indices` are set to
        one, and that per-node pattern is repeated for every batch slot.
        For any other spindle, every entry is one.
        """
        #
        num_nodes = self.metaset.num_nodes
        if self.metaspindle != "node":
            # Every node of every sample is available.
            return onp.ones((meta_batch_size * num_nodes,), dtype=onp.int64)

        # Only the listed nodes are available; tile the pattern per sample.
        selected = onp.zeros((num_nodes,), dtype=onp.int64)
        selected[meta_indices] = 1
        return onp.tile(selected, (meta_batch_size,))

    def set_node_batching(self, with_edge: bool, /) -> None:
        R"""
        Select how node batches are constructed.

        The flag is stored as `self.node_batch_with_edge` and read by
        `node_batch`, which uses `batchize2` when it is true and `batchize`
        otherwise.
        """
        #
        self.node_batch_with_edge = with_edge
        return None

    def node_batch(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int,
        /,
    ) -> List[onpt.NDArray[onp.generic]]:
        R"""
        Construct a batch by node data.

        Returns a list whose first item is the padding mask, followed by the
        input arrays and then the target arrays produced by the batch
        builder.
        """
        # The batch is filled up to a constant size with the padding index.
        # Mark real samples with one and padded ones with zero, then expand
        # that per-sample mark to every node of the corresponding sample.
        sample_marks = onp.zeros((meta_batch_size,), dtype=onp.int64)
        sample_marks[:len(meta_indices)] = 1
        hole_masks = onp.repeat(sample_marks, self.metaset.num_nodes, axis=0)

        # Pick the batch builder according to `set_node_batching`.
        build = batchize2 if self.node_batch_with_edge else batchize
        (inputs_numpy, targets_numpy) = build(
            self.metaset, meta_indices, meta_index_pad, meta_batch_size,
        )
        return [hole_masks, *inputs_numpy, *targets_numpy]

    def train(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int, pinned: List[torch.Tensor],
        /,
    ) -> TIMECOST:
        R"""
        Run one training pass.
        Mostly used for neural network parameter tuning.

        Returns the time costs as a mapping with keys "generate",
        "transfer", "forward" and "backward", each holding a list of
        elapsed seconds. The first "generate" and "transfer" entries belong
        to the node split mask; every following entry belongs to one batch.
        """
        # Node split mask: build it once, then move it to the device.
        tick = time.time()
        split_masks_numpy = (
            self.nodesplit_masks(meta_indices, meta_batch_size)
        )
        generate_costs = [time.time() - tick]
        tick = time.time()
        (split_masks_ondev,) = transfer([split_masks_numpy], self.device)
        transfer_costs = [time.time() - tick]
        forward_costs = []
        backward_costs = []

        #
        timeparts: TIMECOST
        timeparts = {
            "generate": generate_costs,
            "transfer": transfer_costs,
            "forward": forward_costs,
            "backward": backward_costs,
        }

        # With node spindle, the given indices are transductive node indices
        # already folded into the mask above, so the full metaset is batched.
        # Otherwise only the given meta indices are batched.
        batch_indices = (
            list(range(len(self.metaset)))
            if self.metaspindle == "node" else
            meta_indices
        )

        for chunk in xitertools.chunked(batch_indices, meta_batch_size):
            # Batchize only nodes of batch graphs.
            tick = time.time()
            batch_numpy = (
                self.node_batch(list(chunk), meta_index_pad, meta_batch_size)
            )
            generate_costs.append(time.time() - tick)

            # The first transferred item is the padding mask; combine it with
            # the node split mask to get the final node mask.
            tick = time.time()
            (hole_masks_ondev, *batch_ondev) = (
                transfer(batch_numpy, self.device)
            )
            transfer_costs.append(time.time() - tick)
            node_masks_ondev = hole_masks_ondev * split_masks_ondev

            # Let the network arrange device tensors into inputs and targets.
            (inputs_ondev, targets_ondev) = (
                self.neuralnet.reshape(pinned, batch_ondev, node_masks_ondev)
            )

            # Forward.
            tick = time.time()
            outputs_ondev = self.neuralnet.forward(*inputs_ondev)
            forward_costs.append(time.time() - tick)

            # Backward. Skipped entirely when the network reports no resetted
            # parameters; the elapsed time is recorded either way.
            tick = time.time()
            if self.neuralnet.num_resetted_params > 0:
                #
                self.optim.zero_grad()
                self.neuralnet.sidestep(*outputs_ondev, *targets_ondev)
                self.gradclip(self.neuralnet, 1.0)
                self.optim.step()
            backward_costs.append(time.time() - tick)
        return timeparts

    def evaluate(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int, pinned: List[torch.Tensor],
        /,
    ) -> Tuple[List[float], TIMECOST]:
        R"""
        Run one evaluation pass.
        Mostly used for neural network parameter evaluation.

        Returns a pair. The first item holds one value per metric: the sum of
        the per-batch measures divided by the sum of the per-batch sizes. The
        second item holds the time costs under the keys "generate",
        "transfer" and "forward"; building the node split mask is not timed,
        only its transfer is.
        """
        # Node split mask: build it, then move it to the device.
        split_masks_numpy = (
            self.nodesplit_masks(meta_indices, meta_batch_size)
        )
        generate_costs = []
        tick = time.time()
        (split_masks_ondev,) = transfer([split_masks_numpy], self.device)
        transfer_costs = [time.time() - tick]
        forward_costs = []

        #
        timeparts: TIMECOST
        timeparts = {
            "generate": generate_costs,
            "transfer": transfer_costs,
            "forward": forward_costs,
        }
        estimates = []

        # With node spindle, the given indices are transductive node indices
        # already folded into the mask above, so the full metaset is batched.
        # Otherwise only the given meta indices are batched.
        batch_indices = (
            list(range(len(self.metaset)))
            if self.metaspindle == "node" else
            meta_indices
        )

        for chunk in xitertools.chunked(batch_indices, meta_batch_size):
            # Batchize only nodes of batch graphs.
            tick = time.time()
            batch_numpy = (
                self.node_batch(list(chunk), meta_index_pad, meta_batch_size)
            )
            generate_costs.append(time.time() - tick)

            # The first transferred item is the padding mask; combine it with
            # the node split mask to get the final node mask.
            tick = time.time()
            (hole_masks_ondev, *batch_ondev) = (
                transfer(batch_numpy, self.device)
            )
            transfer_costs.append(time.time() - tick)
            node_masks_ondev = hole_masks_ondev * split_masks_ondev

            # Let the network arrange device tensors into inputs and targets.
            (inputs_ondev, targets_ondev) = (
                self.neuralnet.reshape(pinned, batch_ondev, node_masks_ondev)
            )

            # Forward.
            tick = time.time()
            outputs_ondev = self.neuralnet.forward(*inputs_ondev)
            forward_costs.append(time.time() - tick)

            # Performance metrics of this batch.
            estimates.append(
                self.neuralnet.metrics(*outputs_ondev, *targets_ondev),
            )

        # Regroup per-batch results by metric; each entry is a
        # (size, measure) pair.
        records = [list(metric) for metric in xitertools.unzip(estimates)]
        averages = []
        for record in records:
            measure_total = sum(measure for (_, measure) in record)
            size_total = sum(size for (size, _) in record)
            averages.append(measure_total / size_total)
        return (averages, timeparts)
    

    def init_sparse_tensor(self, n=307):
        R"""
        Build a placeholder sparse weight tensor of size `(n, n)`.

        A DGL graph with `n` nodes and five hard-coded demonstration edges
        (0-1, 1-2, 2-3, 3-4, 4-5) is created, and its adjacency matrix is
        converted into a PyTorch sparse COO tensor. The tensor is stored as
        `self.sparse_weights` and also returned.

        NOTE(review): the demonstration edges reference node ids up to 5, so
        this presumably expects `n >= 6` -- confirm against the DGL version
        in use.
        """
        # Empty graph with `n` nodes.
        graph = DGLGraph()
        graph.add_nodes(n)

        # The edges only define a sparsity pattern for demonstration; real
        # usage would add edges based on actual data.
        edges_src = [0, 1, 2, 3, 4]
        edges_dst = [1, 2, 3, 4, 5]
        graph.add_edges(edges_src, edges_dst)

        # Adjacency matrix as a scipy COO matrix.
        adj_matrix = graph.adj_external(scipy_fmt='coo')

        # Stack the coordinate arrays into one ndarray first: building a
        # tensor from a Python list of ndarrays is slow and makes PyTorch
        # emit a warning.
        indices = torch.tensor(
            onp.vstack((adj_matrix.row, adj_matrix.col)), dtype=torch.long,
        )
        values = torch.tensor(adj_matrix.data, dtype=torch.float)
        sparse_tensor = torch.sparse_coo_tensor(indices, values, (n, n))
        self.sparse_weights = sparse_tensor
        return sparse_tensor

    
    def update_weights(self, transformed_feats, node_embeds, epoch):
        R"""
        Refit the sparse combination weights with a positive LassoLars.

        Both inputs are flattened from their second axis on. Each row of
        `node_embeds` is regressed (no intercept, non-negative coefficients,
        `alpha=0`) on the rows of `transformed_feats`, and the coefficient
        matrix `coef_` replaces `self.weights`.

        The number of LARS iterations is `42 - epoch // 100`, clamped to at
        least one so that very large epochs cannot request a non-positive
        iteration count.
        """
        # Detach first: nothing here needs gradients, and `.numpy()` refuses
        # tensors that still require grad.
        feats = transformed_feats.flatten(1).detach()
        targets = node_embeds.flatten(1).detach()

        # Precomputed Gram matrix and design matrix (one column per row of
        # `feats`).
        gram = torch.mm(feats, feats.t()).cpu().numpy()
        X = feats.cpu().numpy().transpose()

        # Targets, one column per row of `node_embeds`, and the matching
        # precomputed X^T y.
        y = targets.transpose(0, 1)
        Xy = torch.mm(feats, y.to(self.device)).cpu().numpy()

        # max_iter: max non zero node selection.
        max_iter = max(1, int(42 - epoch // 100))
        reg = linear_model.LassoLars(
            alpha=0, precompute=gram, max_iter=max_iter,
            fit_intercept=False, positive=True, eps=1e-8,
        )
        reg.fit(X=X, y=y.cpu().numpy(), Xy=Xy)

        # Collect the fitted coefficients as the new weights.
        self.weights = onp.array(reg.coef_)


    def train_mlp(self, transformed_feats, node_embeds, epoch):
        # print("transformed_feats.shape in train_mlp: ", transformed_feats.shape)
        transformed_feats = transformed_feats.flatten(1)
        # print("transformed_feats.shape in train_mlp after flatten: ", transformed_feats.shape)
        node_embeds = node_embeds.flatten(1)
        approximate_embeds = torch.mm(torch.tensor(self.weights).to(transformed_feats.device),transformed_feats)
        # print("predicted_embeds.shape in train_mlp: ", approximate_embeds.shape)
        loss = self.loss_func(approximate_embeds.flatten(), node_embeds.flatten())
        smape_diff = self.smape(approximate_embeds.flatten(), node_embeds.flatten())
        mse_diff = self.mean_squared_error(approximate_embeds.flatten(), node_embeds.flatten())
        self.optimizer_sdgnn.zero_grad()
        loss.backward()
        self.optimizer_sdgnn.step()
        print("#"*10+"Train | Epoch:{} | loss: {} | smape embed: {} | mse embed: {} |".format(epoch, loss.item(), smape_diff, mse_diff)+"#"*10)
    
    def train_sdgnn(
            self,
            meta_indices: List[int], meta_index_pad: Optional[int],
            meta_batch_size: int, pinned: List[torch.Tensor], epoch, post_update=False,
            /,
        ) -> Tuple[List[float], TIMECOST]:
        R"""
        Run one SDGNN training epoch.

        For every batch, the target node embeddings of the frozen target
        model and the transformed node sequence embeddings are collected.
        After all batches, one optimizer step is taken on the transformation
        model (`train_mlp`). Every 20 epochs the current weights are appended
        to `self.weights_traj`; every 40 epochs, or whenever `post_update`
        is true, the weights are refitted (`update_weights`).

        No metric is computed here, so the first returned item is always an
        empty list; the second holds the time costs under the keys
        "generate", "transfer" and "forward".
        """
        #
        timeparts: TIMECOST

        #
        timeparts = {}
        masks_nodesplit_numpy = (
            self.nodesplit_masks(meta_indices, meta_batch_size)
        )
        timeparts["generate"] = []
        elapsed = time.time()
        (masks_nodesplit_ondev,) = (
            transfer([masks_nodesplit_numpy], self.device)
        )
        timeparts["transfer"] = [time.time() - elapsed]

        # If we split data by node, given indices is indeed transductive node
        # indices which has been converted into a mask array before.
        # Thus, we will batch over the full metaset.
        # Otherwise, we only batch over metaset of given meta indices.
        if self.metaspindle == "node":
            #
            batch_indices = list(range(len(self.metaset)))
        else:
            #
            batch_indices = meta_indices

        #
        timeparts["forward"] = []
        transformed_feats_list = []
        target_embeds_list = []
        num_nodes = self.metaset.num_nodes

        for batch in xitertools.chunked(batch_indices, meta_batch_size):
            # Batchize only nodes of batch graphs.
            elapsed = time.time()
            memory_node_numpy = (
                self.node_batch(list(batch), meta_index_pad, meta_batch_size)
            )
            cast(List[float], timeparts["generate"]).append(
                time.time() - elapsed,
            )

            # Node mask need special processing.
            elapsed = time.time()
            (masks_hole_ondev, *memory_node_ondev) = (
                transfer(memory_node_numpy, self.device)
            )
            cast(List[float], timeparts["transfer"]).append(
                time.time() - elapsed,
            )
            node_masks_ondev = masks_hole_ondev * masks_nodesplit_ondev

            # Rearange and reshape device memory tensors to fit task
            # requirements. Only the targets of the frozen model are needed
            # here, so the task targets are ignored.
            (memory_input_ondev, _) = (
                self.neuralnet.reshape(
                    pinned, memory_node_ondev, node_masks_ondev,
                )
            )

            # The network inputs are expected to be exactly nine tensors;
            # only the node features are used directly.
            (_, _, _, _, _, node_feats, _, _, _) = memory_input_ondev

            # Last step of the node sequence embedding.
            snn_embeds = self.get_snn_embedding(node_feats)[-1]

            # Forward: target embedding, then the trainable transformation.
            elapsed = time.time()
            (node_embeds, _, _, _) = (
                self.get_node_embedding(*memory_input_ondev)
            )
            target_embeds_list.append(
                node_embeds.view(num_nodes, snn_embeds.size(1), -1),
            )
            transformed_feats = self.neuralnet.transformation_model(snn_embeds)
            # Use the same middle size as the target reshape above instead
            # of a hard-coded constant; both are flattened again later, so
            # only consistency between the two matters.
            transformed_feats_list.append(
                transformed_feats.view(num_nodes, snn_embeds.size(1), -1),
            )
            cast(List[float], timeparts["forward"]).append(
                time.time() - elapsed,
            )

        # Join all batches along the second axis.
        transformed_feats = torch.cat(transformed_feats_list, dim=1)
        target_embeds = torch.cat(target_embeds_list, dim=1)

        self.train_mlp(transformed_feats, target_embeds, epoch)

        if epoch % 20 == 0:
            # The trajectory list is not created anywhere in this class, so
            # create it on first use unless the caller already provided one.
            if not hasattr(self, "weights_traj"):
                self.weights_traj = []
            self.weights_traj.append(torch.tensor(self.weights).unsqueeze(0))

        # Refit once even if both conditions hold: the refit depends only on
        # its arguments, so a second call would repeat the same work.
        # NOTE(review): the features passed here were computed before the
        # optimizer step taken in `train_mlp`.
        if epoch % 40 == 0 or post_update:
            self.update_weights(transformed_feats, target_embeds, epoch)

        print("="*10 + 'Non-zero weights entries number: {}'.format(self.weights[self.weights!=0].flatten().shape)+"="*10 )

        # No metric is collected in this method.
        return ([], timeparts)
        
    def evaluate_sdgnn(
        self,
        meta_indices: List[int], meta_index_pad: Optional[int],
        meta_batch_size: int, pinned: List[torch.Tensor],
        /, saved_results=True
    ) -> Union[
        Tuple[List[float], TIMECOST],
        Tuple[List[float], TIMECOST, torch.Tensor, torch.Tensor],
    ]:
        R"""
        Evaluate the SDGNN approximation.

        For every batch, node embeddings are approximated as
        `self.weights @ transformed_feats`, compared with the embeddings of
        the target model (SMAPE and MSE), and fed to the MLP to obtain
        predictions, on which the network metrics are computed.

        When `self.save_measurement` is true, one row is appended to the CSV
        file `self.filename`: the first three metrics, the last metric, the
        mean embedding MSE and the mean embedding SMAPE.

        Returns `(metrics, timeparts, preds, ground_truths)` when
        `saved_results` is true, otherwise `(metrics, timeparts)`. Each
        metric is the sum of the per-batch measures divided by the sum of
        the per-batch sizes.
        """
        #
        timeparts: TIMECOST

        #
        timeparts = {}
        estimates = []
        masks_nodesplit_numpy = (
            self.nodesplit_masks(meta_indices, meta_batch_size)
        )
        timeparts["generate"] = []
        elapsed = time.time()
        (masks_nodesplit_ondev,) = (
            transfer([masks_nodesplit_numpy], self.device)
        )
        timeparts["transfer"] = [time.time() - elapsed]

        # If we split data by node, given indices is indeed transductive node
        # indices which has been converted into a mask array before.
        # Thus, we will batch over the full metaset.
        # Otherwise, we only batch over metaset of given meta indices.
        if self.metaspindle == "node":
            #
            batch_indices = list(range(len(self.metaset)))
        else:
            #
            batch_indices = meta_indices

        #
        timeparts["forward"] = []
        smape_diff_list = []
        mse_diff_list = []
        preds = []
        ground_truths = []
        num_nodes = self.metaset.num_nodes

        for batch in xitertools.chunked(batch_indices, meta_batch_size):
            # Batchize only nodes of batch graphs.
            elapsed = time.time()
            memory_node_numpy = (
                self.node_batch(list(batch), meta_index_pad, meta_batch_size)
            )
            cast(List[float], timeparts["generate"]).append(
                time.time() - elapsed,
            )

            # Node mask need special processing.
            elapsed = time.time()
            (masks_hole_ondev, *memory_node_ondev) = (
                transfer(memory_node_numpy, self.device)
            )
            cast(List[float], timeparts["transfer"]).append(
                time.time() - elapsed,
            )
            node_masks_ondev = masks_hole_ondev * masks_nodesplit_ondev

            # Rearange and reshape device memory tensors to fit task
            # requirements.
            (memory_input_ondev, memory_target_ondev) = (
                self.neuralnet.reshape(
                    pinned, memory_node_ondev, node_masks_ondev,
                )
            )

            # The network inputs are expected to be exactly nine tensors;
            # only the node features are used directly.
            (_, _, _, _, _, node_feats, _, _, _) = memory_input_ondev

            # Forward of the full network.
            # NOTE(review): the output is discarded; the call is kept because
            # its duration is what "forward" in the time costs reports, and
            # it may have side effects -- confirm whether it is still wanted.
            elapsed = time.time()
            self.neuralnet.forward(*memory_input_ondev)
            cast(List[float], timeparts["forward"]).append(
                time.time() - elapsed,
            )

            # Last step of the node sequence embedding, transformed and laid
            # out as one row per node.
            snn_embeds = self.get_snn_embedding(node_feats)[-1]
            transformed_feats = self.neuralnet.transformation_model(snn_embeds)
            transformed_feats = (
                transformed_feats.view(num_nodes, snn_embeds.size(1), -1)
            )
            transformed_feats = transformed_feats.view(num_nodes, -1)

            # Match the dtype and device of the features: the stored weights
            # may be a float64 array, which `torch.mm` cannot mix with
            # float32.
            weights = torch.tensor(
                self.weights,
                dtype=transformed_feats.dtype,
                device=transformed_feats.device,
            )
            approximate_embeds = torch.mm(weights, transformed_feats)

            # Back to one row per embedding; the embedding width is taken as
            # twice the second dimension of the sequence embedding.
            approximate_embeds = approximate_embeds.view(
                -1, snn_embeds.size(1), snn_embeds.size(1) * 2,
            )
            approximate_embeds = (
                approximate_embeds.view(-1, approximate_embeds.size(2))
            )

            # Compare with the embeddings of the target model.
            (node_embeds, _, _, _) = (
                self.get_node_embedding(*memory_input_ondev)
            )
            smape_diff_list.append(
                self.smape(approximate_embeds.flatten(), node_embeds.flatten()),
            )
            mse_diff_list.append(
                self.mean_squared_error(
                    approximate_embeds.flatten(), node_embeds.flatten(),
                ),
            )

            # Predictions from the approximated embeddings and their metrics.
            memory_output_ondev = [
                self.neuralnet.mlp(approximate_embeds).squeeze(-1),
            ]
            estimates.append(
                self.neuralnet.metrics(
                    *memory_output_ondev, *memory_target_ondev,
                ),
            )
            preds.append(memory_output_ondev[0])
            ground_truths.append(memory_target_ondev[0].squeeze(-1))

        smape_diff_avg = torch.mean(torch.tensor(smape_diff_list))
        mse_diff_avg = torch.mean(torch.tensor(mse_diff_list))
        print("#"*10+"Eval | smape embed: {} | mse embed: {} |".format(smape_diff_avg, mse_diff_avg)+"#"*10)

        preds = torch.cat(preds)
        ground_truths = torch.cat(ground_truths)

        # Collect mean of all metrics once; each per-batch entry is a
        # (size, measure) pair.
        metrics = [
            sum(measure for (_, measure) in record)
            / sum(size for (size, _) in record)
            for record in (
                [list(metric) for metric in xitertools.unzip(estimates)]
            )
        ]

        # Append the new data to the CSV file.
        if self.save_measurement:
            with open(self.filename, 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([metrics[0],metrics[1], metrics[2], metrics[-1], mse_diff_avg.item(), smape_diff_avg.item()])

        if saved_results:
            return (metrics, timeparts, preds, ground_truths)
        return (metrics, timeparts)

    def continue_train_low_level_mlp(            
            self,
            meta_indices: List[int], meta_index_pad: Optional[int],
            meta_batch_size: int, pinned: List[torch.Tensor], epoch,
            /,):
        R"""
        Fine-tune the MLP on approximated node embeddings.

        Gradients are enabled for the MLP and disabled for the
        transformation model. For every batch, node embeddings are
        approximated as `self.weights @ transformed_feats`, passed through
        the MLP, compared with the first task target by `self.loss_func`,
        and one step of `self.optimizer_mlp` is taken. The loss of the last
        batch is printed together with `epoch`; nothing is printed when
        there is no batch at all. Returns nothing.
        """
        #
        masks_nodesplit_numpy = (
            self.nodesplit_masks(meta_indices, meta_batch_size)
        )
        (masks_nodesplit_ondev,) = (
            transfer([masks_nodesplit_numpy], self.device)
        )

        # If we split data by node, given indices is indeed transductive node
        # indices which has been converted into a mask array before.
        # Thus, we will batch over the full metaset.
        # Otherwise, we only batch over metaset of given meta indices.
        if self.metaspindle == "node":
            #
            batch_indices = list(range(len(self.metaset)))
        else:
            #
            batch_indices = meta_indices

        # Train only the MLP; keep the transformation model fixed.
        for param in self.neuralnet.mlp.parameters():
            param.requires_grad = True
        for param in self.neuralnet.transformation_model.parameters():
            param.requires_grad = False
        self.neuralnet.mlp.train()

        # Stays unset when there is nothing to batch over.
        loss = None
        num_nodes = self.metaset.num_nodes
        for batch in xitertools.chunked(batch_indices, meta_batch_size):
            # Batchize only nodes of batch graphs.
            memory_node_numpy = (
                self.node_batch(list(batch), meta_index_pad, meta_batch_size)
            )

            # Node mask need special processing.
            (masks_hole_ondev, *memory_node_ondev) = (
                transfer(memory_node_numpy, self.device)
            )
            node_masks_ondev = masks_hole_ondev * masks_nodesplit_ondev

            # Rearange and reshape device memory tensors to fit task
            # requirements.
            (memory_input_ondev, memory_target_ondev) = (
                self.neuralnet.reshape(
                    pinned, memory_node_ondev, node_masks_ondev,
                )
            )

            # The network inputs are expected to be exactly nine tensors;
            # only the node features are used directly.
            (_, _, _, _, _, node_feats, _, _, _) = memory_input_ondev

            # Forward of the full network.
            # NOTE(review): the output is discarded; the call is kept because
            # it may have side effects (e.g. random state or running
            # statistics) -- confirm whether it is still wanted.
            self.neuralnet.forward(*memory_input_ondev)

            # Last step of the node sequence embedding, transformed and laid
            # out as one row per node.
            snn_embeds = self.get_snn_embedding(node_feats)[-1]
            transformed_feats = self.neuralnet.transformation_model(snn_embeds)
            transformed_feats = (
                transformed_feats.view(num_nodes, snn_embeds.size(1), -1)
            )
            transformed_feats = transformed_feats.view(num_nodes, -1)

            # Match the dtype and device of the features: the stored weights
            # may be a float64 array, which `torch.mm` cannot mix with
            # float32.
            weights = torch.tensor(
                self.weights,
                dtype=transformed_feats.dtype,
                device=transformed_feats.device,
            )
            approximate_embeds = torch.mm(weights, transformed_feats)

            # Back to one row per embedding; the embedding width is taken as
            # twice the second dimension of the sequence embedding.
            approximate_embeds = approximate_embeds.view(
                -1, snn_embeds.size(1), snn_embeds.size(1) * 2,
            )
            approximate_embeds = (
                approximate_embeds.view(-1, approximate_embeds.size(2))
            )

            # One optimization step of the MLP against the first target.
            memory_output_ondev = [
                self.neuralnet.mlp(approximate_embeds).squeeze(-1),
            ]
            loss = self.loss_func(
                memory_output_ondev[0].flatten(),
                memory_target_ondev[0].flatten(),
            )
            self.optimizer_mlp.zero_grad()
            loss.backward()
            self.optimizer_mlp.step()

        if loss is None:
            # No batch was processed, so there is no loss to report.
            return
        print("#"*10+"Epoch:{} low levle mlp loss: {}".format(epoch, loss.item())+"#"*10)
            